{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c9a343cf",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-15T01:32:50.307641Z",
     "start_time": "2024-05-15T01:32:45.111185Z"
    }
   },
   "outputs": [],
   "source": [
    "import gym\n",
    "import numpy as np\n",
    "from IPython import display\n",
    "import matplotlib\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f7654f3d",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-15T01:32:52.102347Z",
     "start_time": "2024-05-15T01:32:52.095455Z"
    }
   },
   "outputs": [],
   "source": [
    "class GymHelper:\n",
    "    \"\"\"Show frames rendered by a Gym environment in place inside a notebook cell.\n",
    "\n",
    "    Args:\n",
    "        env: environment whose render() returns an image array\n",
    "            (e.g. one created with render_mode=\"rgb_array\").\n",
    "        figsize: matplotlib figure size in inches.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, figsize=(3, 3)):\n",
    "        self.env = env\n",
    "        self.figsize = figsize\n",
    "        # Keep a handle on our own figure so later plots elsewhere cannot hijack plt.gcf()\n",
    "        self.fig = plt.figure(figsize=figsize)\n",
    "        self.img = plt.imshow(env.render())\n",
    "\n",
    "    def render(self, title=None):\n",
    "        \"\"\"Redraw the figure with the environment's current frame.\n",
    "\n",
    "        Args:\n",
    "            title: optional text shown above the frame; a falsy value leaves\n",
    "                the existing title untouched.\n",
    "        \"\"\"\n",
    "        self.img.set_data(self.env.render())\n",
    "        # Set the title before displaying so it labels the frame being shown,\n",
    "        # not the following one.\n",
    "        if title:\n",
    "            self.img.axes.set_title(title)\n",
    "        display.display(self.fig)\n",
    "        # wait=True postpones the clear until the next output arrives, so the\n",
    "        # frame just displayed stays visible until it is replaced.\n",
    "        display.clear_output(wait=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "dac78066",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-15T01:33:07.047910Z",
     "start_time": "2024-05-15T01:33:02.508462Z"
    }
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from tqdm import *\n",
    "import collections\n",
    "import time\n",
    "import random\n",
    "import sys\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "3a494234",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-15T01:34:21.428155Z",
     "start_time": "2024-05-15T01:34:06.864647Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgkAAAF2CAYAAADk/gtxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy88F64QAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA1qUlEQVR4nO3de3RU9b3H/c8kmZmEXIaEQCYDIcQQ5BKIEhBIKXdQKgLSx9KbhVWO69gKR464WqHHh3hOSzi4amuPVVcv3vrYxp5iLK1IiSKhNNpigBIuIoUAARIikMwkIUxuv+cPD1MHNkJuDAnv11rfZWbvX/Z85yc4H/fVZowxAgAAuERYqBsAAAA3JkICAACwREgAAACWCAkAAMASIQEAAFgiJAAAAEuEBAAAYImQAAAALBESAACAJUICAEnSa6+9phEjRigqKko2m03z58+XzWZr17a2bt0qm82mrVu3tun3Bg0apDlz5rTrPQF0vohQNwAg9D7++GPdf//9uuuuu/Tss8/K6XTK4/Hosccea9f2Ro8erffee0/Dhw/v5E4BXE+EBAD66KOP1NTUpK9//euaPHlyYPnAgQPbtb24uDiNHz++s9oDECIcbgBucosXL9bEiRMlSQsXLpTNZtOUKVOUm5t72eGGi4cDNm3apNGjRysqKkpDhw7VCy+8EDTO6nDDkSNH9OUvf1kej0dOp1NJSUmaPn26du/efVlPV9s+gOuDPQnATe7xxx/XHXfcoYceekhr1qzR1KlTFRcXp9/+9reW4//+979rxYoVeuyxx5SUlKRf/OIXWrJkiQYPHqxJkyZd8X2+8IUvqKWlRevWrdPAgQN15swZFRcXq6amplO2D6DzERKAm1x6enrg3IGMjIyrHiY4c+aM/vKXvwQORUyaNEnvvPOOfv3rX1/xS/zs2bM6ePCgfvzjH+vrX/96YPmCBQs6ZfsAugYhAUCb3HbbbUHnKkRGRmrIkCE6duzYFX8nISFB6enpevLJJ9XS0qKpU6cqKytLYWGXH/Fsz/YBdA3OSQDQJn369LlsmdPpVENDwxV/x2az6Z133tGdd96pdevWafTo0erbt6/+7d/+TbW1tR3ePoCuwZ4EANdFamqqfvnLX0r65GqK3/72t8rNzVVjY6Oef/75EHcHwAp7EgBcd0OGDNF//Md/aOTIkdq5c2eo2wFwBexJANDl9uzZo6VLl+q+++5TRkaGHA6HtmzZoj179rT7hk0Auh4hAUCXc7vdSk9P17PPPqvy8nLZbDbdcsst+uEPf6hly5aFuj0AV2AzxphQNwEAAG48nJMAAAAsERIAAIAlQgIAALAU0pDw7LPPKi0tTZGRkcrOztaf//znULYDAAA+JWQh4bXXXtPy5cv1ve99T7t27dLnP/95zZ49W8ePHw9VSwAA4FNCdnXDuHHjNHr0aD333HOBZcOGDdP8+fOVl5cXipYAAMCnhOQ+CY2NjSopKbnsJiqzZs1ScXHxVX+/tbVVp06dUmxs7GXPuwcAAFdmjFFtba08Ho/lQ9Y+LSQh4cyZM2ppaVFSUlLQ8qSkJFVWVl423u/3y+/3B16fPHky8GhbAADQduXl5RowYMBnjgnpiYuX7gUwxljuGcjLy5PL5QoUAQEAgI6JjY296piQhITExESFh4dfttegqqrqsr0LkrRy5Up5vd5AlZeXX69WAQDoka7lcH1IQoLD4VB2drYKCwuDlhcWFionJ+ey8U6nU3FxcUEFAAC6Vsge8PTII4/o/vvv15gxYzRhwgT97Gc/0/Hjx/Xggw+GqiUAAPApIQsJCxcu1NmzZ/Wf//mfqqioUGZmpjZu3KjU1NRQtQQAAD6lWz4F0ufzyeVyhboNAAC6La/Xe9XD9zy7AQAAWCIkAAAAS4QEAAB6gIgIh6KiOvdQfMhOXAQAAB3jsEcruleiMgZNUUrKbSo/vUM7PnhNxrRcNjYi3KnUAeN0S3qOCresvabt
ExIAAOhmoqP6KLFPulL7j1XGoGmKCHMqzpmiMOPQ4YRinT17NGh8n4RbdEvqBN0+9CuKDfeoUIQEAAB6lOiovhqcNlEJCalKS/6cYiM96mVPlCM8RpLkjh+p/p5R8nor1dx8QRERTt2S9jmNSL9bg/p9TjEOt5qa/Vd5l38iJAAAcAMLD3fIFZOs20b8P4qKiVX/hGzFR6XJHh6tcJtDRq1qMY2qaShTVd1+RTpdCg+PkD0iXoPTP687RixWUuxIOcKjJUlNIiQAANDtuVwejctapOSkkYpxJCkharDCbJ98dbe0+lXffFrehhM6cXaH6mrP6cjRYp08vUtJicM0d/p/K8oRr/jIW9r9/oQEAABuUJlD79aA5Gx5YrMVEeaUMUaNLXWqbTyp8/6zOlb5N52s2K2zZ4/qTM1hGdOq/u7bNXncvynG2U+uyIEden9CAgAAN6jz3lqdbzyjcJtDF5q9OttwSHUNFSo/uUuHj2+Xr/a06s9/HBjfr88QfX7st5TcO0sxjsufqtxWhAQAAG5Q5ZUfaPiwO3WgqkC1dadVefojlR54Q62mRa2tzYFxYWER6u3qry9M+y/1ixmmqIg+1/Qo6KshJAAAcIM6U/MPHTn6F1XVfKgTJ3bL31RnOW5I2nSNHHG3+kbfql72xE57f0ICAAA3sPd2/eKqY/wX6pUQlSGf/6QiIxIUHtY5X+/clhkAgG7u6Klibfvb/yjCFqUz5z9Uq8UdF9uDkAAAQDdnTKsOHt2sHX///+RvrFX1hTIZYzq8XUICAAA9QKtp1p5D63X02F/V4D+n6gtHOrxNQgIAAD3EhUav3tvzM31cVabG5nrVXDjWoT0KhAQAAHqQBn+N/vjnx+Sr+Vj+Zp/qGivbHRQICQAA9DAtLX79cdt35a2uVK3/pM43n21XUCAkAADQA/nqK1T0wY905uNjOlv/kfwtvjZvg5AAAECPZHTqzB7t3J+vqspDqvTtlr+5VmrDHgVupgQAQDd1z6BBKq+r0+4zZyzXG9OqoxV/UaTTJVtYmBz2WMVF9L/m7bMnAQCAbmjpqFH63ezZemXmTGX37fuZYw8dL9TZc2WKMFGqOr//mt+DkAAAQDd0z6BBsoeFaXh8vDJ69/7MsS2tTdqx/xX9o6xIvcKv/emQhAQAALqhOX/8o4pOndL/+9e/6rVDh646vqn5vAr/9n2dOFVyze9hM51x38brzOfzyeVyhboNAAC6La/Xq7i4uM8cw54EAABgiZAAAAAsdXpIyM3Nlc1mCyq32x1Yb4xRbm6uPB6PoqKiNGXKFO3bt6+z2wAAAB3UJXsSRowYoYqKikCVlpYG1q1bt05PPfWUnnnmGe3YsUNut1szZ85UbW1tV7QCAADaqUtCQkREhNxud6D6/t/1m8YY/fjHP9b3vvc9LViwQJmZmXr55Zd1/vx5/frXv+6KVgAAQDt1SUg4dOiQPB6P0tLS9OUvf1lHjnzyTOuysjJVVlZq1qxZgbFOp1OTJ09WcXHxFbfn9/vl8/mCCgAAdK1ODwnjxo3TK6+8oj/96U/6+c9/rsrKSuXk5Ojs2bOqrKyUJCUlBd/IISkpKbDOSl5enlwuV6BSUlI6u20AAHCJLr9PQn19vdLT0/Wd73xH48eP1+c+9zmdOnVKycnJgTEPPPCAysvLtWnTJstt+P1++f3+wGufz0dQAACgA26I+yRER0dr5MiROnToUOAqh0v3GlRVVV22d+HTnE6n4uLiggoAAHStLg8Jfr9fBw4cUHJystLS0uR2u1VYWBhY39jYqKKiIuXk5HR1KwAAoA06/VHRjz76qO655x4NHDhQVVVV+v73vy+fz6dFixbJZrNp+fLlWrNmjTIyMpSRkaE1a9aoV69e+upXv9rZrQAAgA7o9JBw4sQJfeUrX9GZM2fUt29fjR8/Xu+//75SU1MlSd/5znfU0NCgb3/726qurta4ceO0efNmxcbGdnYrAACgA3jAEwAAN6Eb4sRFAADQ
PRESAACAJUICAACwREgAAKCThNlsGpqYqISoqFC30ikICQAAdJIJKSmaOHCgJg8apDinM9TtdBghAQCATpIUHS1JcjmdcoaHh7ibjiMkAADQSf5w8KDKvV4VHT2qj8+fD3U7HdbpN1MCAOBm1dTaqj8dPhzqNjoNexIAAIAlQgIAALBESAAAAJYICQAAwBIhAQAAWCIkAAAAS4QEAABgiZAAAAAsERIAAIAlQgIAALBESAAAAJYICQAAwBIhAQAAWCIkAAAAS4QEAABgiZAAAAAsERIAAD1ebK9eGuh2KzyMr722iAh1AwAAdKVIh0NDBw1Sr6goxfbqpX1HjoS6pW6DSAUA6NHCwsJkt9slSU6HI8TddC/sSQAA9GjnL1zQ/iNHlJSQoH+cOBHqdrqVNu9J2LZtm+655x55PB7ZbDa98cYbQeuNMcrNzZXH41FUVJSmTJmiffv2BY3x+/1atmyZEhMTFR0drblz5+oE/+IAAF2kprZWB48dU0tLS6hb6VbaHBLq6+uVlZWlZ555xnL9unXr9NRTT+mZZ57Rjh075Ha7NXPmTNXW1gbGLF++XAUFBcrPz9f27dtVV1enOXPm8C8PAIAbiekASaagoCDwurW11bjdbrN27drAsgsXLhiXy2Wef/55Y4wxNTU1xm63m/z8/MCYkydPmrCwMLNp06Zrel+v12skURRFURTVzvJ6vVf9vu3UExfLyspUWVmpWbNmBZY5nU5NnjxZxcXFkqSSkhI1NTUFjfF4PMrMzAyMuZTf75fP5wsqAADQtTo1JFRWVkqSkpKSgpYnJSUF1lVWVsrhcCg+Pv6KYy6Vl5cnl8sVqJSUlM5sGwAAWOiSSyBtNlvQa2PMZcsu9VljVq5cKa/XG6jy8vJO6xUAAFjr1JDgdrsl6bI9AlVVVYG9C263W42Njaqurr7imEs5nU7FxcUFFQAA6FqdGhLS0tLkdrtVWFgYWNbY2KiioiLl5ORIkrKzs2W324PGVFRUaO/evYExAAAg9Np8M6W6ujr94x//CLwuKyvT7t27lZCQoIEDB2r58uVas2aNMjIylJGRoTVr1qhXr1766le/KklyuVxasmSJVqxYoT59+ighIUGPPvqoRo4cqRkzZnTeJwMAAB1zTdccfsq7775reSnFokWLjDGfXAa5evVq43a7jdPpNJMmTTKlpaVB22hoaDBLly41CQkJJioqysyZM8ccP378mnvgEkiKoiiK6lhdyyWQNmOMUTfj8/nkcrlC3QYAAN2W1+u96jl+POAJAABYIiQAAABLhAQAAGCJkAAAACwREgAAgCVCAgAAsERIAAAAlggJAADAEiEBAABYIiQAAABLhAQAAGCJkAAAACwREgAAgCVCAgAAsERIAAAAlggJAADAEiEBAABYIiQAAABLhAQAAGCJkAAAACwREgAAgCVCAgAAsERIAAAAlggJAADAEiEBAABYIiQAAABLhAQAAGCpzSFh27Ztuueee+TxeGSz2fTGG28ErV+8eLFsNltQjR8/PmiM3+/XsmXLlJiYqOjoaM2dO1cnTpzo0AcBAACdq80hob6+XllZWXrmmWeuOOauu+5SRUVFoDZu3Bi0fvny5SooKFB+fr62b9+uuro6zZkzRy0tLW3/BAAAoGuYDpBkCgoKgpYtWrTIzJs374q/U1NTY+x2u8nPzw8sO3nypAkLCzObNm26pvf1er1GEkVRFEVR7Syv13vV79suOSdh69at6tevn4YMGaIHHnhAVVVVgXUlJSVqamrSrFmzAss8Ho8yMzNVXFzcFe0AAIB2iOjsDc6ePVv33XefUlNTVVZWpscff1zTpk1TSUmJnE6nKisr5XA4FB8fH/R7SUlJqqystNym3++X3+8PvPb5fJ3dNgAAuESnh4SFCxcGfs7MzNSYMWOUmpqqN998UwsWLLji7xljZLPZLNfl5eXpiSee6OxWAQDAZ+jySyCTk5OVmpqqQ4cOSZLcbrcaGxtVXV0dNK6q
qkpJSUmW21i5cqW8Xm+gysvLu7ptAABuel0eEs6ePavy8nIlJydLkrKzs2W321VYWBgYU1FRob179yonJ8dyG06nU3FxcUEFAAC6VpsPN9TV1ekf//hH4HVZWZl2796thIQEJSQkKDc3V1/84heVnJyso0ePatWqVUpMTNS9994rSXK5XFqyZIlWrFihPn36KCEhQY8++qhGjhypGTNmdN4nAwAAHXNN1xx+yrvvvmt5KcWiRYvM+fPnzaxZs0zfvn2N3W43AwcONIsWLTLHjx8P2kZDQ4NZunSpSUhIMFFRUWbOnDmXjeESSIqiKIrqurqWSyBtxhijbsbn88nlcoW6DQAAui2v13vVw/c8uwEAAFgiJAAAAEuEBAAAYImQAAAALBESAACAJUICAACwREgAAACWCAkAAMASIQEAAFgiJAAAAEuEBAAAYImQAAAALBESAACAJUICAACwREgAAACWCAkAAMASIQEAAFgiJAAAAEuEBAAAYImQAAAALBESAACAJUICAACwREgAAACWCAkAAMASIQEAAFgiJAAAAEuEBAAAYImQAAAALLUpJOTl5Wns2LGKjY1Vv379NH/+fB08eDBojDFGubm58ng8ioqK0pQpU7Rv376gMX6/X8uWLVNiYqKio6M1d+5cnThxouOfBgAAdJo2hYSioiI99NBDev/991VYWKjm5mbNmjVL9fX1gTHr1q3TU089pWeeeUY7duyQ2+3WzJkzVVtbGxizfPlyFRQUKD8/X9u3b1ddXZ3mzJmjlpaWzvtkAACgY0wHVFVVGUmmqKjIGGNMa2urcbvdZu3atYExFy5cMC6Xyzz//PPGGGNqamqM3W43+fn5gTEnT540YWFhZtOmTdf0vl6v10iiKIqiKKqd5fV6r/p926FzErxeryQpISFBklRWVqbKykrNmjUrMMbpdGry5MkqLi6WJJWUlKipqSlojMfjUWZmZmAMAAAIvYj2/qIxRo888ogmTpyozMxMSVJlZaUkKSkpKWhsUlKSjh07FhjjcDgUHx9/2ZiLv38pv98vv98feO3z+drbNgAAuEbt3pOwdOlS7dmzR7/5zW8uW2ez2YJeG2MuW3apzxqTl5cnl8sVqJSUlPa2DQAArlG7QsKyZcu0YcMGvfvuuxowYEBgudvtlqTL9ghUVVUF9i643W41Njaqurr6imMutXLlSnm93kCVl5e3p20AANAGbQoJxhgtXbpUr7/+urZs2aK0tLSg9WlpaXK73SosLAwsa2xsVFFRkXJyciRJ2dnZstvtQWMqKiq0d+/ewJhLOZ1OxcXFBRUAAOhi13olgzHGfOtb3zIul8ts3brVVFRUBOr8+fOBMWvXrjUul8u8/vrrprS01HzlK18xycnJxufzBcY8+OCDZsCAAebtt982O3fuNNOmTTNZWVmmubmZqxsoiqIo6jrUtVzd0KaQcKU3evHFFwNjWltbzerVq43b7TZOp9NMmjTJlJaWBm2noaHBLF261CQkJJioqCgzZ84cc/z48Wvug5BAURRFUR2rawkJtv/78u9WfD6fXC5XqNsAAKDb8nq9Vz18z7MbAACAJUICAACwREgAAACWCAkAAMASIQEAAFgiJAAAAEuEBAAAYImQAAAALBESAACAJUICAACwREgAAACWCAkAAMASIQEAAFgiJAAAAEuEBAAAYImQAAAALBESAACAJUICAACwREgAAACWCAkAAMASIQEAAFgiJAAAAEuEBAAAYImQAAAALBESAACAJUICAACwREgAAACWCAkAAMBSm0JCXl6exo4dq9jYWPXr10/z58/XwYMHg8YsXrxYNpstqMaPHx80xu/3a9myZUpMTFR0dLTmzp2rEydOdPzTAACATtOmkFBUVKSHHnpI77//vgoLC9Xc3KxZs2apvr4+aNxdd92lioqKQG3cuDFo/fLly1VQUKD8/Hxt375ddXV1mjNnjlpaWjr+iQAAQOcwHVBVVWUkmaKiosCyRYsWmXnz
5l3xd2pqaozdbjf5+fmBZSdPnjRhYWFm06ZN1/S+Xq/XSKKoHl+Rkf8shyP0/dwMxZxTN0t5vd6rft9GqAO8Xq8kKSEhIWj51q1b1a9fP/Xu3VuTJ0/WD37wA/Xr10+SVFJSoqamJs2aNSsw3uPxKDMzU8XFxbrzzjsvex+/3y+/3x947fP5OtI20C3Y7dLbb0sREZIx0pEj0pNPfrLOGKmmRjp2LKQt9jjMORCs3SHBGKNHHnlEEydOVGZmZmD57Nmzdd999yk1NVVlZWV6/PHHNW3aNJWUlMjpdKqyslIOh0Px8fFB20tKSlJlZaXle+Xl5emJJ55ob6tAtxUe/skXliTdeqv0859/8nNrq3T48CdfaJLU3Czt3y998EFo+uxJmHPgn9odEpYuXao9e/Zo+/btQcsXLlwY+DkzM1NjxoxRamqq3nzzTS1YsOCK2zPGyGazWa5buXKlHnnkkcBrn8+nlJSU9rYOdFsX/4qEh0tDhkgZGZ+8bm2VKiul48c/+T9ev1/6wx+kbdtC12tPwZzjZtaukLBs2TJt2LBB27Zt04ABAz5zbHJyslJTU3Xo0CFJktvtVmNjo6qrq4P2JlRVVSknJ8dyG06nU06nsz2tAj3ap7/A+veXPJ5/rpswQWps/OQLrL5e+tnPpLfekjg/uGOYc9xM2nR1gzFGS5cu1euvv64tW7YoLS3tqr9z9uxZlZeXKzk5WZKUnZ0tu92uwsLCwJiKigrt3bv3iiEBwLWx2f5ZkZFSbKwUFyclJkpf+IIUExPqDnse5hw9WZv2JDz00EP69a9/rd///veKjY0NnEPgcrkUFRWluro65ebm6otf/KKSk5N19OhRrVq1SomJibr33nsDY5csWaIVK1aoT58+SkhI0KOPPqqRI0dqxowZnf8JgZuIMZ/8s7FReucdaceOT143NUm7dkn/d64xOhFzjp6sTSHhueeekyRNmTIlaPmLL76oxYsXKzw8XKWlpXrllVdUU1Oj5ORkTZ06Va+99ppiY2MD43/0ox8pIiJCX/rSl9TQ0KDp06frpZdeUnh4eMc/EdCDXfxCkj45ce7ChU9+bmqSNm+WXnvtn+NqaqS6uuveYo/DnONmZjPm038FugefzyeXyxXqNoAuZbd/chLcxTPta2ulAwc++TJqaZF275Zeeumf47vf3+QbD3OOm4nX61VcXNxnjunQfRIAdB2bLVytrYv1/PO/lDFSVZX0xz+GuquejTkHghESgBuUzRahlpZF+uUvfxnqVm4azDkQjKdAAgAAS+xJQBCbzaZevXrJ5XKpd+/e6t27t1wul1wulxwOh/bs2aMPP/xQFy6evQUA6LEICTepXr16ye12Kzk5WR6PR263O/C6d+/eiouLC1RsbKxiY2Nlt9t14MAB7dmzRxs3btTbb7+tc+fOhfqjAAC6CCGhBwgLC1NYWJgiIiIUFham8PBwhYeHKyoqSh6PR6mpqRo0aJAGDhyogQMHatCgQUpMTFRERITsdrvsdnvg54vbuJLbb79do0aN0vz583Xy5Elt3LhRr7zyisrKytTQ0MDjvgGgByEkdBMRERHq1auXYmJiFB0drejoaMXExCg2NlaJiYnq37+/UlJS5PF45PF4NGDAAPXr1++K95640nMyrkV4eLhcLpfi4uI0bNgw/fu//7u2bNmigoIC7dixQ4cPH1Z1dXW7tw8AuDEQEm4g4eHh6t27t/r166fExEQlJiaqb9++6tu3r/r06aP4+PjAXSov/ty7d++QPdfiYtAIDw/XzJkzNX36dB05ckR/+ctf9Oc//1nbtm3T4cOH1draGpL+AAAdQ0i4Tmw2m2w2m8LCwhQfH6/U1FQNHDhQAwYMCBwGcLvdiomJUVRUlKKiohQZGanIyEhFRUXJbrd36P/+r4ewsDANHjxY6enpmjdvnk6dOqX3339fv/rVr/Tee++psbFR3fDeXQBw0yIkdKGYmJjASYDZ2dmaPn26Jk2apOTk5KDQ
cPHni9Xd2Wy2wJURQ4cO1de+9jUdPnxY+fn5KigoUEVFhWpqajh/AQBucISETmSz2TRo0CClp6crLS1NWVlZuu222zRy5Mir3vqypwoLC5PT6dTw4cP1xBNPaPny5XrnnXe0bds2/fWvf9WePXvk9/tD3SYAwAIhoYPsdrvGjh2rO+64Q2PHjtWgQYMCJw86HI5Qt3dDsdlsSkhI0H333ac5c+boyJEj2r17tzZt2qQNGzbI5/OFukUAwKcQEq6RzWYLnCOQkJCgKVOmaMqUKZo0aZJiY2MVGRkpp9P5mZcP4p+ioqI0YsQIDRs2TPPmzdO6dev02muv6Te/+Y0OHjyo+vp6NTc3h7pNALipERI+Q2RkpDwej/r3769bbrlFd9xxh8aPH6/bb789MKYnnEMQSmFhYYqJiVFMTIwefvhhPfjgg/rggw+0fv16vffeezp06JCqq6s54REAQoCQcIk+ffpo1KhRuv322zVs2DClp6crIyND/fv3JxB0sYt7ayZOnKgJEybo6NGjKikp0fbt27V582YdPHgw1C0CwE3lpg0JF68sCA8PV1pamqZOnapp06Zp2LBh6t27t+Lj4xUdHR3qNm9a4eHhSk9PV3p6umbPnq2lS5eqpKREr776qrZs2SK/38/9FwCgi91UISEmJiZwU6LMzExNmzZNkydPVkpKSuDWxuwtuPHExsYqJiZGgwcP1n333aeysjK9+uqrKiws1JEjR3T69GkORwBAF+jxISE1NVVDhgxRenq6RowYodtuu00jRoxQfHx8qFtDG3z6vhIZGRnKzc3VsmXLVFxcrLfffls7duzQ7t271dDQEOpWAaDH6HEhISwsTGPHjtX48eM1bty4wCWJSUlJioyMDHV76ER9+vTRPffco1mzZuno0aP68MMP9fvf/15vvvmmqqqqQt0eAHR73TokREVFqVevXnK5XJo6dWrgkkSXyyWn0ymHw8EliTcBp9OpW2+9VUOGDNHMmTO1evVqbd68WS+//LIOHDig2tpaNTU1hbpNAOh2unVI+K//+i99/vOf1+jRoxUR0a0/CjqBzWZTr169lJqaqgceeEDf/OY39d5772nDhg3as2ePjh8/rhMnTqi2tjbUrQJAt9Ctv1kfeOCBm/Z2x7i68PBwTZw4URMnTtSZM2d0+PBhHTlyRB9++KF27dqlnTt36uTJk6FuEwBuWN06JADX6uKjt8eNGye/3y+fzyev16sjR45o27ZtKiwsVGlpqVpaWtTc3MzllQAgQgJuQk6nU3379lXfvn2Vnp6uGTNm6IknntDHH3+s4uJibd26VTt27NDZs2d19uxZnTt3LtQtA0BIEBJwU/v047ndbrcWLFigBQsWqL6+Xh9++KH279+v0tJSffTRRzp48KA++ugj9jIAuGkQEgAL0dHRys7OVnZ2tpqbm1VVVaWqqiqdPHlS77//voqLi/Xee+9xXwYAPRohAbiKiIiIwOO/s7KyNHPmTDU1Ncnv92v79u3aunWr3n33XZ0+fVr19fWqr69XS0tLqNsGgA4jJABtYLPZ5HA45HA4FB0drblz52ru3LkyxujAgQPasWOHdu7cqUOHDuno0aM6fvy46uvrQ902ALRLm+409Nxzz2nUqFGKi4tTXFycJkyYoLfeeiuw3hij3NxceTweRUVFacqUKdq3b1/QNvx+v5YtW6bExMTAf2RPnDjROZ8GCBGbzabhw4dr0aJFevrpp/XSSy/pxRdf1AsvvKDc3FzdfffdcrvdoW4TANqkTXsSBgwYoLVr12rw4MGSpJdfflnz5s3Trl27NGLECK1bt05PPfWUXnrpJQ0ZMkTf//73NXPmTB08eFCxsbGSpOXLl+sPf/iD8vPz1adPH61YsUJz5sxRSUmJwsPDO/8TAiHQr18/9evXL3DJZW1trerq6vTRRx9py5Ytevvtt3Xw4EE1NzerqamJwxM3kYsny376eSRXeu1wOBQfH6+EhAQlJCQoPj7+iq/j
4+O1efNm/eAHP5DP51Nzc3OoPyp6AJvp4OPzEhIS9OSTT+qb3/ymPB6Pli9fru9+97uSPtlrkJSUpP/+7//Wv/7rv8rr9apv37761a9+pYULF0qSTp06pZSUFG3cuFF33nnnNb2nz+eTy+WS1+vlZkroVi7+dTPGqLKyUsXFxSoqKtLf//53ffzxxzp9+rSqq6slfXKpZmFhoSZNmhTKlm8q7ZnzsLAwORwOOZ1O2e32wOGoS+viut69e6t3795yuVyBx9Jf/PnT66weV38tT6mtqanRU089pQ0bNmj//v2EBVzRtXyHtjsktLS06H//93+1aNEi7dq1S5GRkUpPT9fOnTt1++23B8bNmzdPvXv31ssvv6wtW7Zo+vTpOnfuXNBTGLOysjR//nw98cQTlu/l9/vl9/sDr30+n1JSUggJ6DF8Pp8++ugj7d+/X/v379eHH36oQ4cO6Y477tBLL70U6vZuGuHh4frGN76h3/3ud+rVq5diYmIUHR0dqIuvrdZdXH6119freTIffvih1q9fr9/+9rfas2fPdXlPdC/X8h3a5hMXS0tLNWHCBF24cEExMTEqKCjQ8OHDVVxcLElKSkoKGp+UlKRjx45JkiorKwO7zy4dU1lZecX3zMvLu2KAAHqCuLg4jRkzRmPGjFFTU5POnDmjjz/+WOfOndM3vvGNULd3U7HZbPqXf/mXwP/5X+mfF3++UQ+TDh06VI899pjuvfdebdq0SevWrVNVVZU6uPMYN5k2h4Rbb71Vu3fvVk1NjdavX69FixapqKgosP7S3WHGmKvuIrvamJUrV+qRRx4JvL64JwHoiex2u5KTk5WcnBzqVtDNhYeHa/jw4RoyZIgeeOAB/fCHP9TLL7+sU6dOqbGxMdTtoRto834vh8OhwYMHa8yYMcrLy1NWVpaefvrpwJnbl+4RqKqqCuxdcLvdamxsDBxztRpjxel0Bq6ouFgAgGsTERGh2NhY5ebmasuWLVq1apXGjBkT6rbQDXT44JgxRn6/X2lpaXK73SosLAysa2xsVFFRkXJyciRJ2dnZstvtQWMqKiq0d+/ewBgAQNdJS0vT6tWr9cILL2jdunWBq9UAK2063LBq1SrNnj1bKSkpqq2tVX5+vrZu3apNmzbJZrNp+fLlWrNmjTIyMpSRkaE1a9aoV69e+upXvypJcrlcWrJkiVasWKE+ffooISFBjz76qEaOHKkZM2Z0yQcEAFxu5MiRGjp0qL7+9a/r1Vdf1dq1a+Xz+dTU1BTq1nADaVNIOH36tO6//35VVFTI5XJp1KhR2rRpk2bOnClJ+s53vqOGhgZ9+9vfVnV1tcaNG6fNmzcH7pEgST/60Y8UERGhL33pS2poaND06dP10ksv3bAn/wBAT3Xx/JcVK1Zo8eLFevrpp/Xmm29qz5493LsDkjrhPgmhwH0SAKBrHDhwQAUFBcrPz1dpaWmo20EX6tL7JIQSIQEAuk5zc7OOHDmiP/7xj1q7dq3OnDnDpZM9ECEBANBuLS0tOn/+vJ588kn95je/0fHjx7l0sgchJAAAOsXhw4f16quvasOGDSopKQl1O+gEhAQAQKcxxmjv3r3avHmzfvrTn6qsrCzULV13cXFxGjFihEaPHq1Ro0apV69e8nq9qqmpkc/nU01Njbxer7xe72Wv6+rqQt1+EEICAKDTNTU1qaamRi+88IJ++MMfqqampkddOhkWFian0ymn06no6GgNHTpUEydOVE5OjoYOHaqYmBhFRkbK6XTKZrOptbU1UC0tLVd8ffFmgtXV1Tp37pzOnj2rc+fOqbq6Oujnc+fO6dy5c/L7/UHbutJ7tLa2tuucEUICAKDLGGNUVVWlZ555Rhs3btSuXbu65QmO4eHhio+Pl9vtVlJSkgYNGqSsrCxlZ2drxIgRl33PXMvTOK20ZW6MMWpoaAjspbi4N+JKr+vq6uT3+3XhwgX5/X41NjYGfr60Lly4IImQAAC4Tvbv36/f//73evXV
V7Vv375Qt3NViYmJGjp0aODZFqmpqUpLS1NqaqoSEhKu29M6O0tzc7MaGhp0/vz5wD8v1qWvq6ur9d3vfpeQAAC4fpqamnTs2DFt2LBBa9asUXV1tVpbW0PWj81mU1hYmMLCwhQXF6dRo0YpJydH48aNU1pamuLi4uRyuRQdHa2IiDY/77Dbast3KCEBANCpWltbVV9fr3Xr1ul3v/udjhw5cl0unYyIiFBcXJx69+6thIQEDRkyRGPHjlVOTo6GDRumqKioQHBo7yGDnoCQAAC4IRw6dEj5+fkqKCjQrl27OnXb4eHhSkpKUmpqqtLT03XLLbcoIyNDQ4cOVUZGhlwuV6e+X09BSAAA3DCMMdq3b5/+9Kc/6Sc/+YmOHz/eru3YbDb16dNHt99+e6AGDBigvn37KikpSS6X66beQ3CtCAkAgBtOU1OTamtr9fzzz+t//ud/dO7cOcvDEBEREYqMjFRkZKR69+6t22+/XePGjdPnPvc5ZWRkyOFwyOFwyG63d7sTDG8EhAQAwA3t5MmTev755/XWW29pz5496tevn9xutzwejzIyMjRy5EhlZWVpxIgRcjgcoW63RyEkAAC6hb179+qvf/2r+vfvrwEDBmjgwIH8d72LteU79Oa55gMAcMPJzMxUZmZmqNvAFXAwBwAAWCIkAAAAS4QEAABgiZAAAAAsERIAAIAlQgIAALBESAAAAJYICQAAwBIhAQAAWCIkAAAAS4QEAABgiZAAAAAstSkkPPfccxo1apTi4uIUFxenCRMm6K233gqsX7x4sWw2W1CNHz8+aBt+v1/Lli1TYmKioqOjNXfuXJ04caJzPg0AAOg0bQoJAwYM0Nq1a/XBBx/ogw8+0LRp0zRv3jzt27cvMOauu+5SRUVFoDZu3Bi0jeXLl6ugoED5+fnavn276urqNGfOHLW0tHTOJwIAAJ3CZowxHdlAQkKCnnzySS1ZskSLFy9WTU2N3njjDcuxXq9Xffv21a9+9SstXLhQknTq1CmlpKRo48aNuvPOO6/pPdvyLGwAAPBPbfkObfc5CS0tLcrPz1d9fb0mTJgQWL5161b169dPQ4YM0QMPPKCqqqrAupKSEjU1NWnWrFmBZR6PR5mZmSouLm5vKwAAoAtEtPUXSktLNWHCBF24cEExMTEqKCjQ8OHDJUmzZ8/Wfffdp9TUVJWVlenxxx/XtGnTVFJSIqfTqcrKSjkcDsXHxwdtMykpSZWVlVd8T7/fL7/fH3jt8/na2jYAAGijNoeEW2+9Vbt371ZNTY3Wr1+vRYsWqaioSMOHDw8cQpCkzMxMjRkzRqmpqXrzzTe1YMGCK27TGCObzXbF9Xl5eXriiSfa2ioAAOiANh9ucDgcGjx4sMaMGaO8vDxlZWXp6aefthybnJys1NRUHTp0SJLkdrvV2Nio6urqoHFVVVVKSkq64nuuXLlSXq83UOXl5W1tGwAAtFGH75NgjAk6FPBpZ8+eVXl5uZKTkyVJ2dnZstvtKiwsDIypqKjQ3r17lZOTc8X3cDqdgcsuLxYAAOhabTrcsGrVKs2ePVspKSmqra1Vfn6+tm7dqk2bNqmurk65ubn64he/qOTkZB09elSrVq1SYmKi7r33XkmSy+XSkiVLtGLFCvXp00cJCQl69NFHNXLkSM2YMaNLPiAAAGifNoWE06dP6/7771dFRYVcLpdGjRqlTZs2aebMmWpoaFBpaaleeeUV1dTUKDk5WVOnTtVrr72m2NjYwDZ+9KMfKSIiQl/60pfU0NCg6dOn66WXXlJ4eHinfzgAANB+Hb5PQihwnwQAANrnutwnAQAA9GyEBAAAYImQAAAALBESAACAJUICAACwREgAAACWCAkAAMASIQEAAFgiJAAAAEuEBAAAYImQAAAALBESAACAJUICAACwREgAAACWCAkAAMASIQEAAFgiJAAAAEuEBAAAYImQAAAALBESAACAJUICAACwREgAAACWCAkAAMASIQEAAFgiJAAA
AEuEBAAAYImQAAAALBESAACAJUICAACwREgAAACWCAkAAMBSRKgbaA9jjCTJ5/OFuBMAALqXi9+dF79LP0u3DAm1tbWSpJSUlBB3AgBA91RbWyuXy/WZY2zmWqLEDaa1tVUHDx7U8OHDVV5erri4uFC31C34fD6lpKQwZ23EvLUdc9Y+zFvbMWdtZ4xRbW2tPB6PwsI++6yDbrknISwsTP3795ckxcXF8QejjZiz9mHe2o45ax/mre2Ys7a52h6EizhxEQAAWCIkAAAAS902JDidTq1evVpOpzPUrXQbzFn7MG9tx5y1D/PWdsxZ1+qWJy4CAICu1233JAAAgK5FSAAAAJYICQAAwBIhAQAAWOqWIeHZZ59VWlqaIiMjlZ2drT//+c+hbilktm3bpnvuuUcej0c2m01vvPFG0HpjjHJzc+XxeBQVFaUpU6Zo3759QWP8fr+WLVumxMRERUdHa+7cuTpx4sR1/BTXV15ensaOHavY2Fj169dP8+fP18GDB4PGMG+Xe+655zRq1KjATWsmTJigt956K7CeObu6vLw82Ww2LV++PLCMebtcbm6ubDZbULnd7sB65uw6Mt1Mfn6+sdvt5uc//7nZv3+/efjhh010dLQ5duxYqFsLiY0bN5rvfe97Zv369UaSKSgoCFq/du1aExsba9avX29KS0vNwoULTXJysvH5fIExDz74oOnfv78pLCw0O3fuNFOnTjVZWVmmubn5On+a6+POO+80L774otm7d6/ZvXu3ufvuu83AgQNNXV1dYAzzdrkNGzaYN9980xw8eNAcPHjQrFq1ytjtdrN3715jDHN2NX/729/MoEGDzKhRo8zDDz8cWM68XW716tVmxIgRpqKiIlBVVVWB9czZ9dPtQsIdd9xhHnzwwaBlQ4cONY899liIOrpxXBoSWltbjdvtNmvXrg0su3DhgnG5XOb55583xhhTU1Nj7Ha7yc/PD4w5efKkCQsLM5s2bbpuvYdSVVWVkWSKioqMMcxbW8THx5tf/OIXzNlV1NbWmoyMDFNYWGgmT54cCAnMm7XVq1ebrKwsy3XM2fXVrQ43NDY2qqSkRLNmzQpaPmvWLBUXF4eoqxtXWVmZKisrg+bL6XRq8uTJgfkqKSlRU1NT0BiPx6PMzMybZk69Xq8kKSEhQRLzdi1aWlqUn5+v+vp6TZgwgTm7ioceekh33323ZsyYEbScebuyQ4cOyePxKC0tTV/+8pd15MgRSczZ9datHvB05swZtbS0KCkpKWh5UlKSKisrQ9TVjevinFjN17FjxwJjHA6H4uPjLxtzM8ypMUaPPPKIJk6cqMzMTEnM22cpLS3VhAkTdOHCBcXExKigoEDDhw8P/IeXObtcfn6+du7cqR07dly2jj9r1saNG6dXXnlFQ4YM0enTp/X9739fOTk52rdvH3N2nXWrkHCRzWYLem2MuWwZ/qk983WzzOnSpUu1Z88ebd++/bJ1zNvlbr31Vu3evVs1NTVav369Fi1apKKiosB65ixYeXm5Hn74YW3evFmRkZFXHMe8BZs9e3bg55EjR2rChAlKT0/Xyy+/rPHjx0tizq6XbnW4ITExUeHh4ZclwaqqqstSJRQ4G/iz5svtdquxsVHV1dVXHNNTLVu2TBs2bNC7776rAQMGBJYzb1fmcDg0ePBgjRkzRnl5ecrKytLTTz/NnF1BSUmJqqqqlJ2drYiICEVERKioqEg/+clPFBEREfjczNtni46O1siRI3Xo0CH+rF1n3SokOBwOZWdnq7CwMGh5YWGhcnJyQtTVjSstLU1utztovhobG1VUVBSYr+zsbNnt9qAxFRUV2rt3b4+dU2OMli5dqtdff11btmxRWlpa0Hrm7doZY+T3+5mzK5g+fbpKS0u1e/fuQI0ZM0Zf+9rXtHv3bt1yyy3M2zXw+/06cOCAkpOT+bN2vYXibMmOuHgJ5C9/+Uuzf/9+s3z5chMdHW2OHj0a6tZCora21uzatcvs2rXL
SDJPPfWU2bVrV+CS0LVr1xqXy2Vef/11U1paar7yla9YXio0YMAA8/bbb5udO3eaadOm9ehLhb71rW8Zl8tltm7dGnSJ1fnz5wNjmLfLrVy50mzbts2UlZWZPXv2mFWrVpmwsDCzefNmYwxzdq0+fXWDMcyblRUrVpitW7eaI0eOmPfff9/MmTPHxMbGBv47z5xdP90uJBhjzE9/+lOTmppqHA6HGT16dODStZvRu+++ayRdVosWLTLGfHK50OrVq43b7TZOp9NMmjTJlJaWBm2joaHBLF261CQkJJioqCgzZ84cc/z48RB8muvDar4kmRdffDEwhnm73De/+c3A37u+ffua6dOnBwKCMczZtbo0JDBvl7t43wO73W48Ho9ZsGCB2bdvX2A9c3b98KhoAABgqVudkwAAAK4fQgIAALBESAAAAJYICQAAwBIhAQAAWCIkAAAAS4QEAABgiZAAAAAsERIAAIAlQgIAALBESAAAAJYICQAAwNL/D6nThTQih6sOAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 600x600 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Smoke test: step LunarLander with random actions (at most 100 steps) and draw every frame inline.\n",
    "env=gym.make(\"LunarLander-v2\",render_mode=\"rgb_array\")\n",
    "env.reset()\n",
    "gym_helper=GymHelper(env,figsize=(6,6))\n",
    "for i in range(100):\n",
    "    gym_helper.render(title=str(i))\n",
    "    # Pick a random action from the environment's action space\n",
    "    action=env.action_space.sample()\n",
    "    observation,reward,terminated,truncated,info=env.step(action)\n",
    "    # Stop as soon as the episode either terminates or is truncated\n",
    "    done=terminated or truncated\n",
    "    if done:break\n",
    "gym_helper.render(\"finish\")\n",
    "env.close()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "347acdbb",
   "metadata": {},
   "source": [
    "## SAC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "62af9336",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-15T02:23:44.931927Z",
     "start_time": "2024-05-15T02:23:44.926423Z"
    }
   },
   "outputs": [],
   "source": [
    "#定义策略网络\n",
    "class PolicyModel(nn.Module):\n",
    "    def __init__(self,input_dim,output_dim):\n",
    "        super(PolicyModel,self).__init__()\n",
    "        #使用全连接层构建一个简单的神经网络，共享部分网络层\n",
    "        self.fc=nn.Sequential([\n",
    "            nn.Linear(input_dim,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,output_dim),\n",
    "            nn.Softmax(dim=1)\n",
    "        ])\n",
    "    def forward(self,x):\n",
    "        action_prob=self.fc(x)\n",
    "        return action_prob"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "66cc88cd",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-15T01:38:24.438289Z",
     "start_time": "2024-05-15T01:38:24.432763Z"
    }
   },
   "outputs": [],
   "source": [
    "#定义Q网络模型\n",
    "class QvalueModel(nn.Module):\n",
    "    def __init__(self,input_dim,output_dim):\n",
    "        super(QvalueModel,self).__init__()\n",
    "        #使用全连接层构建一个简单的神经网络，共享部分网络层\n",
    "        self.fc=nn.Sequential([\n",
    "            nn.Linear(input_dim,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,128),\n",
    "            nn.ReLU(),\n",
    "            nn.Linear(128,output_dim),\n",
    "        ])\n",
    "    def forward(self,x):\n",
    "        value=self.fc(x)\n",
    "        return value"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a2578d69",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-15T01:42:10.252613Z",
     "start_time": "2024-05-15T01:42:10.246079Z"
    }
   },
   "outputs": [],
   "source": [
    "class ReplayBuffer:\n",
    "    def __init__(self,max_size):\n",
    "        self.max_size=max_size\n",
    "        self.buffer=collections.deque(maxlen=self.max_size)\n",
    "    def add(self,state,action,reward,next_state,done):\n",
    "        experience=(state,action,reward,next_state,done)\n",
    "        self.buffer.append(experience)\n",
    "    def sample(self,batch_size):\n",
    "        batch=random.sample(self.buffer,batch_size)\n",
    "        state,action,reward,next_state,done=zip(*batch)\n",
    "        return state,action,reward,next_state,done\n",
    "    def __len__(self):\n",
    "        return len(self.buffer)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "be32357b",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-15T02:21:06.367362Z",
     "start_time": "2024-05-15T02:21:06.348285Z"
    }
   },
   "outputs": [],
   "source": [
    "class SAC:\n",
    "    \"\"\"Soft Actor-Critic agent for an environment with a discrete action space.\n",
    "\n",
    "    Holds a categorical policy (actor), two Q-networks with soft-updated target\n",
    "    copies, and a learnable temperature alpha stored as log_alpha so that\n",
    "    alpha = exp(log_alpha) stays positive.\n",
    "\n",
    "    Args:\n",
    "        env: environment exposing observation_space.shape[0] and action_space.n.\n",
    "        lr: learning rate used by all four Adam optimizers.\n",
    "        gamma: discount factor applied to the next-state soft value.\n",
    "        rho: soft-update rate for the target Q-networks.\n",
    "        buffer_size: kept for backward compatibility; this class does not use it.\n",
    "        target_entropy: entropy target for the temperature update; None keeps\n",
    "            the original default of -log2(action_space.n).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, env, lr=0.002, gamma=0.99, rho=0.01, buffer_size=10000,\n",
    "                 target_entropy=None):\n",
    "        self.env = env\n",
    "        self.gamma = gamma\n",
    "        self.rho = rho\n",
    "        state_dim = env.observation_space.shape[0]\n",
    "        n_actions = env.action_space.n\n",
    "        # NOTE(review): update() measures the policy entropy in nats and it is\n",
    "        # essentially never negative, so with this negative default the term\n",
    "        # (entropy - target) is always positive and alpha is pushed towards 0.\n",
    "        # Discrete-action SAC commonly uses a positive fraction of\n",
    "        # log(n_actions) instead - confirm the intended value.\n",
    "        if target_entropy is None:\n",
    "            target_entropy = -np.log2(n_actions)\n",
    "        self.target_entropy = float(target_entropy)\n",
    "        # Use the GPU when one is available, otherwise the CPU\n",
    "        self.device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "        self.actor = PolicyModel(state_dim, n_actions).to(self.device)\n",
    "        self.q1 = QvalueModel(state_dim, n_actions).to(self.device)\n",
    "        self.q2 = QvalueModel(state_dim, n_actions).to(self.device)\n",
    "        self.target_q1 = QvalueModel(state_dim, n_actions).to(self.device)\n",
    "        self.target_q2 = QvalueModel(state_dim, n_actions).to(self.device)\n",
    "        # Target networks start as exact copies of the online Q-networks\n",
    "        self.target_q1.load_state_dict(self.q1.state_dict())\n",
    "        self.target_q2.load_state_dict(self.q2.state_dict())\n",
    "        self.optimizer_actor = torch.optim.Adam(self.actor.parameters(), lr=lr)\n",
    "        self.optimizer_q1 = torch.optim.Adam(self.q1.parameters(), lr=lr)\n",
    "        self.optimizer_q2 = torch.optim.Adam(self.q2.parameters(), lr=lr)\n",
    "        # alpha is learned through its logarithm, which keeps alpha = exp(log_alpha) > 0\n",
    "        self.log_alpha = torch.tensor([0.0], device=self.device, requires_grad=True)\n",
    "        self.optimizer_log_alpha = torch.optim.Adam([self.log_alpha], lr=lr)\n",
    "\n",
    "    def choose_action(self, state):\n",
    "        \"\"\"Sample an action index from the current policy for a single state.\"\"\"\n",
    "        state = torch.FloatTensor(np.array([state])).to(self.device)\n",
    "        with torch.no_grad():\n",
    "            # The actor returns one tensor of action probabilities\n",
    "            action_prob = self.actor(state)\n",
    "        return torch.distributions.Categorical(action_prob).sample().item()\n",
    "\n",
    "    def _soft_update(self, net, target_net):\n",
    "        \"\"\"Move target_net towards net: target = rho * net + (1 - rho) * target.\"\"\"\n",
    "        with torch.no_grad():\n",
    "            for param, target_param in zip(net.parameters(), target_net.parameters()):\n",
    "                target_param.copy_(param * self.rho + target_param * (1 - self.rho))\n",
    "\n",
    "    def update(self, batch):\n",
    "        \"\"\"Run one gradient step on the critics, the actor and the temperature.\n",
    "\n",
    "        Args:\n",
    "            batch: iterable of (state, action, reward, next_state, done)\n",
    "                transitions. NOTE(review): ReplayBuffer.sample in this notebook\n",
    "                returns the five fields already grouped by column, which is not\n",
    "                this layout - confirm how the training loop builds `batch`.\n",
    "        \"\"\"\n",
    "        states, actions, rewards, next_states, dones = zip(*batch)\n",
    "        states = torch.as_tensor(np.array(states), dtype=torch.float32, device=self.device)\n",
    "        # gather() needs int64 indices, so actions must not be a float tensor\n",
    "        actions = torch.as_tensor(np.array(actions), dtype=torch.int64, device=self.device).view(-1, 1)\n",
    "        rewards = torch.as_tensor(np.array(rewards), dtype=torch.float32, device=self.device).view(-1, 1)\n",
    "        next_states = torch.as_tensor(np.array(next_states), dtype=torch.float32, device=self.device)\n",
    "        dones = torch.as_tensor(np.array(dones), dtype=torch.float32, device=self.device).view(-1, 1)\n",
    "        # Current temperature as a constant; only alpha_loss below may change log_alpha\n",
    "        alpha = torch.exp(self.log_alpha).detach()\n",
    "\n",
    "        # TD target: no gradient may flow into the actor or the target networks\n",
    "        with torch.no_grad():\n",
    "            next_action_prob = self.actor(next_states)\n",
    "            log_next_prob = torch.log(next_action_prob + 1e-9)\n",
    "            target_q_min = torch.min(self.target_q1(next_states), self.target_q2(next_states))\n",
    "            # Expected soft value of the next state under the current policy\n",
    "            next_value = torch.sum(\n",
    "                next_action_prob * (target_q_min - alpha * log_next_prob), dim=1, keepdim=True\n",
    "            )\n",
    "            # Use the batch's dones tensor, not a global `done` flag\n",
    "            td_target = rewards + (1 - dones) * self.gamma * next_value\n",
    "\n",
    "        # Critic update\n",
    "        q1_loss = F.mse_loss(self.q1(states).gather(1, actions), td_target)\n",
    "        q2_loss = F.mse_loss(self.q2(states).gather(1, actions), td_target)\n",
    "        self.optimizer_q1.zero_grad()\n",
    "        self.optimizer_q2.zero_grad()\n",
    "        q1_loss.backward()\n",
    "        q2_loss.backward()\n",
    "        self.optimizer_q1.step()\n",
    "        self.optimizer_q2.step()\n",
    "\n",
    "        # Actor update: the Q-values act as constants here\n",
    "        action_prob = self.actor(states)\n",
    "        log_prob = torch.log(action_prob + 1e-9)\n",
    "        with torch.no_grad():\n",
    "            q_min = torch.min(self.q1(states), self.q2(states))\n",
    "        actor_loss = torch.sum(\n",
    "            action_prob * (alpha * log_prob - q_min), dim=1, keepdim=True\n",
    "        ).mean()\n",
    "        self.optimizer_actor.zero_grad()\n",
    "        actor_loss.backward()\n",
    "        self.optimizer_actor.step()\n",
    "\n",
    "        # Temperature update: alpha * (policy entropy - target entropy)\n",
    "        with torch.no_grad():\n",
    "            entropy_gap = -torch.sum(action_prob * log_prob, dim=1, keepdim=True) - self.target_entropy\n",
    "        alpha_loss = (torch.exp(self.log_alpha) * entropy_gap).mean()\n",
    "        # This step must go through the log_alpha optimizer, not the actor's\n",
    "        self.optimizer_log_alpha.zero_grad()\n",
    "        alpha_loss.backward()\n",
    "        self.optimizer_log_alpha.step()\n",
    "\n",
    "        # Let the target critics slowly track the online critics\n",
    "        self._soft_update(self.q1, self.target_q1)\n",
    "        self._soft_update(self.q2, self.target_q2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b2787441",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2024-05-15T02:23:46.040688Z",
     "start_time": "2024-05-15T02:23:45.996783Z"
    }
   },
   "outputs": [
    {
     "ename": "TypeError",
     "evalue": "list is not a Module subclass",
     "output_type": "error",
     "traceback": [
      "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[1;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_12244\\3916576633.py\u001b[0m in \u001b[0;36m<cell line: 5>\u001b[1;34m()\u001b[0m\n\u001b[0;32m      3\u001b[0m \u001b[1;31m#batch_size=32\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      4\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[0magent\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mSAC\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m      6\u001b[0m \u001b[0meps_rewards\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      7\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_12244\\70837989.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, env, lr, gamma, rho, buffer_size)\u001b[0m\n\u001b[0;32m      8\u001b[0m         \u001b[1;31m#判断可用设备是GPU还是CPU\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      9\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m\"cuda\"\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mis_available\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32melse\u001b[0m \u001b[1;34m\"cpu\"\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 10\u001b[1;33m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mactor\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mPolicyModel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mobservation_space\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     11\u001b[0m         
\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mq1\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mQvalueModel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mobservation_space\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     12\u001b[0m         \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mq2\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mQvalueModel\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mobservation_space\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[1;33m[\u001b[0m\u001b[1;36m0\u001b[0m\u001b[1;33m]\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0menv\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0maction_space\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mn\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mto\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mdevice\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Local\\Temp\\ipykernel_12244\\2782240083.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, input_dim, output_dim)\u001b[0m\n\u001b[0;32m      4\u001b[0m         \u001b[0msuper\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mPolicyModel\u001b[0m\u001b[1;33m,\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      5\u001b[0m         \u001b[1;31m#使用全连接层构建一个简单的神经网络，共享部分网络层\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 6\u001b[1;33m         self.fc=nn.Sequential([\n\u001b[0m\u001b[0;32m      7\u001b[0m             \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mLinear\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0minput_dim\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;36m128\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m      8\u001b[0m             \u001b[0mnn\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mReLU\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python310\\site-packages\\torch\\nn\\modules\\container.py\u001b[0m in \u001b[0;36m__init__\u001b[1;34m(self, *args)\u001b[0m\n\u001b[0;32m     89\u001b[0m         \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     90\u001b[0m             \u001b[1;32mfor\u001b[0m \u001b[0midx\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0margs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 91\u001b[1;33m                 \u001b[0mself\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0madd_module\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mstr\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0midx\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmodule\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m     92\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m     93\u001b[0m     \u001b[1;32mdef\u001b[0m \u001b[0m_get_item_by_idx\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mself\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0miterator\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0midx\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;33m->\u001b[0m \u001b[0mT\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;32m~\\AppData\\Roaming\\Python\\Python310\\site-packages\\torch\\nn\\modules\\module.py\u001b[0m in \u001b[0;36madd_module\u001b[1;34m(self, name, module)\u001b[0m\n\u001b[0;32m    442\u001b[0m         \"\"\"\n\u001b[0;32m    443\u001b[0m         \u001b[1;32mif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mModule\u001b[0m\u001b[1;33m)\u001b[0m \u001b[1;32mand\u001b[0m \u001b[0mmodule\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 444\u001b[1;33m             raise TypeError(\"{} is not a Module subclass\".format(\n\u001b[0m\u001b[0;32m    445\u001b[0m                 torch.typename(module)))\n\u001b[0;32m    446\u001b[0m         \u001b[1;32melif\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mname\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mtorch\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0m_six\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mstring_classes\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n",
      "\u001b[1;31mTypeError\u001b[0m: list is not a Module subclass"
     ]
    }
   ],
   "source": [
    "# Training configuration: number of episodes and the per-episode step cap.\n",
    "max_episodes=1500\n",
    "max_steps=1000\n",
    "\n",
    "agent=SAC(env)\n",
    "eps_rewards=[]  # summed reward of every finished episode, in order\n",
    "\n",
    "for episode in tqdm(range(max_episodes)):\n",
    "    state,_=env.reset()\n",
    "    # Transitions gathered during this episode and the running reward total.\n",
    "    trajectory,episode_return=[],0\n",
    "    for step in range(max_steps):\n",
    "        action=agent.choose_action(state)\n",
    "        next_state,reward,terminated,truncated,info=env.step(action)\n",
    "        # The episode is over on either a terminal state or a truncation.\n",
    "        done=terminated or truncated\n",
    "        transition=(state,action,reward,next_state,done)\n",
    "        trajectory.append(transition)\n",
    "        episode_return+=reward\n",
    "        state=next_state\n",
    "        if done:\n",
    "            break\n",
    "    # The agent is updated once per episode with all collected transitions.\n",
    "    agent.update(trajectory)\n",
    "    eps_rewards.append(episode_return)\n",
    "    # Report progress on every 40th episode without breaking the tqdm bar.\n",
    "    if not episode % 40:\n",
    "        tqdm.write(\"Episode%s:%s\" % (episode,episode_return))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "39385e90",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render a single episode with the agent trained in the training cell above.\n",
    "# The trained `agent` is reused on purpose: rebuilding an agent here (the cell\n",
    "# previously created an `A2C` instance, while this notebook trains `SAC`) would\n",
    "# discard the learned policy and roll out an untrained one instead.\n",
    "observation,_=env.reset()\n",
    "gym_helper=GymHelper(env,figsize=(3,3))\n",
    "i=0  # number of environment steps taken so far; shown as the frame title\n",
    "while True:\n",
    "    gym_helper.render(title=str(i))\n",
    "    action=agent.choose_action(observation)\n",
    "    observation,reward,terminated,truncated,info=env.step(action)\n",
    "    # Stop on either a terminal state or a truncation.\n",
    "    done=terminated or truncated\n",
    "    i+=1\n",
    "    time.sleep(0.5)  # slow the loop down so each frame is visible\n",
    "    if done:\n",
    "        break\n",
    "# Draw the last frame once more with a final title, then release the env.\n",
    "gym_helper.render(title=\"finished\")\n",
    "env.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
